
### Essential functions for the doubly robust estimating equation, based on Xiang's document
# Size of the helper indicator vectors vec.1 / vec.0 used throughout this file.
# Discrete L0: 32 rows, or 64 when gop.type == "incorrect".
# Otherwise (continuous L0): one row per observation, i.e. the global n.
nb <- if (l0.type == "discrete") {
  if (gop.type == "incorrect") 64 else 32
} else {
  n
}
# Constant all-ones / all-zeros vectors of length nb, used to evaluate
# functions at a fixed treatment or covariate level.
vec.1 <- rep(1, nb)
vec.0 <- rep(0, nb)


# Evaluate the data-generating model at parameter vector `pars` for each row
# of `data`.
#
# pars : concatenation c(alpha, beta, gamma); the block lengths are taken from
#        the globals alpha.true, beta.true and gamma.true.
#        (For testing: pars = c(c(alpha.true), c(beta.true), c(gamma.true)).)
# data : matrix or data.frame with columns "l0.1", "l0.2", "a0", "l1", "a1".
# cnt  : unused; kept for interface compatibility with existing callers.
#
# Returns a list with:
#   k        : 2 + 2*(1 - a0) + (1 - l1)
#   p.a0     : expit(l0 %*% gamma["a0", 1:2])
#   p.l1     : phi45[i, 2 - a0[i]] for each row i
#   p.a1     : expit(cbind(l1, a0, l0) %*% gamma["a1", ])
#   p.y      : p.y.cond[i, 4*a0 + 2*l1 + a1 + 1] for each row i
#   p.y.cond : nb x 8 matrix from GetProb (columns p000 ... p111)
#   theta, phi45 : intermediate model quantities
# Requires expit() and GetProb() to be defined elsewhere.
func.DataGen <- function(pars, data, cnt) {
  p.alpha <- length(alpha.true)
  p.beta <- length(beta.true)
  p.gamma <- length(gamma.true)

  # as.matrix so that %*% also works when `data` is a data.frame.
  l0 <- as.matrix(data[, c("l0.1", "l0.2")])
  a0 <- data[, "a0"]
  l1 <- data[, "l1"]
  a1 <- data[, "a1"]
  nb <- nrow(data)
  rows <- seq_len(nb)

  alpha <- matrix(pars[seq_len(p.alpha)], ncol = 2)
  beta <- matrix(pars[p.alpha + seq_len(p.beta)], ncol = 2)
  gamma <- matrix(pars[p.alpha + p.beta + seq_len(p.gamma)], nrow = 2)
  rownames(gamma) <- c("a0", "a1")

  theta <- exp(l0 %*% t(alpha))
  phi123 <- exp(l0 %*% t(beta[1:3, ]))
  phi45 <- expit(l0 %*% t(beta[4:5, ]))
  phi <- cbind(phi123, phi45)
  p.y.cond <- t(sapply(rows, function(i) GetProb(theta[i, ], phi[i, ])))
  colnames(p.y.cond) <- c("p000", "p001", "p010", "p011", "p100",
                          "p101", "p110", "p111")

  # Vectorized per-row lookups (same results as the former sapply loops).
  cell <- 4 * a0 + 2 * l1 + a1 + 1
  p.y <- p.y.cond[cbind(rows, cell)]
  p.l1 <- phi45[cbind(rows, 2 - a0)]

  # Nuisance functions
  p.a1 <- c(expit(cbind(l1, a0, l0) %*% gamma["a1", ]))
  p.a0 <- c(expit(l0 %*% gamma["a0", 1:2]))

  k <- 2 + 2 * (1 - a0) + (1 - l1)

  list(k = k, p.a0 = p.a0,
       p.l1 = p.l1, p.a1 = p.a1,
       p.y = p.y, p.y.cond = p.y.cond,
       theta = theta, phi45 = phi45)
}

# Fitted outcome predictor y.hat.func(l0, a0, l1, a1, y.hat).
# The implementation installed depends on the globals y.hat.type and l0.type.
# `l0` is either a single covariate vector (one observation) or a matrix with
# one row per observation. Matrix input is detected with is.matrix(): since
# R 4.0 a matrix has class c("matrix", "array"), so the former
# `class(l0) == "matrix"` test had length 2 and errors inside if() in R >= 4.2.
if (y.hat.type %in% c("saturated", "MLE")) {
  # Saturated model: y.hat is a vector of p.y, indexed by the cell
  # l0.2 + 2*a0 + 4*l1 + 8*a1 + 1 (assumes these are 0/1 -- confirm).
  y.hat.func <- function(l0, a0, l1, a1, y.hat) {
    l0.2 <- if (is.matrix(l0)) l0[, 2] else l0[2]
    idx <- l0.2 + 2 * a0 + 4 * l1 + 8 * a1 + 1
    return(y.hat[idx])
  }
  # Different function for continuous L0: y.hat is the matrix p.y.cond and
  # each row i is read at column 4*a0 + 2*l1 + a1 + 1.
  if (l0.type == "continuous") {
    y.hat.func <- function(l0, a0, l1, a1, y.hat) {
      idx <- 4 * a0 + 2 * l1 + a1 + 1
      return(y.hat[cbind(seq_len(nrow(y.hat)), idx)])
    }
  }
} else if (y.hat.type %in% c("logit", "diy")) {
  # Logistic model: y.hat is a coefficient vector for (l1, a0, l0, a1).
  # Always returns a plain vector (the single-observation case used to
  # return a 1x1 matrix).
  y.hat.func <- function(l0, a0, l1, a1, y.hat) {
    design <- if (is.matrix(l0)) cbind(l1, a0, l0, a1) else c(l1, a0, l0, a1)
    return(c(expit(design %*% y.hat)))
  }
}

# Collapse the data over the binary column "l0.2.tilde".
#
# Keeps the rows with l0.2.tilde == 0, dropping columns 3 and 8+ (keeps
# columns 1:2 and 4:7). Each kept row's count becomes the sum of its own count
# and the count of the positionally matching l0.2.tilde != 0 row. This
# assumes the two halves are listed in the same order -- confirm with caller.
#
# data : matrix/data.frame with a "l0.2.tilde" column and at least 7 columns.
# cnt  : counts, one per row of `data`.
# Returns list(data = collapsed rows, cnt = summed counts). Errors if the two
# halves differ in size, where `cnt` would otherwise be silently recycled.
rm.l0.tilde <- function(data, cnt) {
  keep <- data[, "l0.2.tilde"] == 0
  if (!isTRUE(sum(keep) == sum(!keep))) {
    stop("rm.l0.tilde: rows with l0.2.tilde == 0 and != 0 must pair up one-to-one",
         call. = FALSE)
  }
  # drop = FALSE keeps a matrix even when a single row is retained.
  list(data = data[keep, c(1:2, 4:7), drop = FALSE],
       cnt = cnt[keep] + cnt[!keep])
}


# Pick P(Y | a0, l1, a1, L0) for each row out of the conditional-probability
# matrix p.y.cond (columns ordered p000 ... p111).
#
# a0, l1, a1 : 0/1 vectors (or scalars, recycled) selecting the column
#              4*a0 + 2*l1 + a1 + 1 for every row.
# p.y.cond   : matrix with one row per observation and 8 columns.
# nb         : number of rows to read; defaults to nrow(p.y.cond).
# Returns a numeric vector of length nb (empty when nb == 0).
p.y.func <- function(a0, l1, a1, p.y.cond, nb = nrow(p.y.cond)) {
  cell <- 4 * a0 + 2 * l1 + a1 + 1
  # seq_len, not 1:nb: 1:0 is c(1, 0) and would index a spurious element.
  return(p.y.cond[cbind(seq_len(nb), cell)])
}

# Propensity of the second treatment: expit of the linear predictor built
# from the design (l1, a0, l0) and the coefficient row gamma["a1", ].
# Returns a plain numeric vector (one value per design row).
p.a1.func <- function(l1, a0, l0, gamma) {
  design <- cbind(l1, a0, l0)
  coef.a1 <- gamma["a1", ]
  c(expit(design %*% coef.a1))
}

# p.l1.func <- function(a0, phi45, nb){
#     return(sapply(1:nb, function(i) phi45[i,2 - a0[i]]))
# }
# Default model for P(L1 = 1 | a0, L0), used when l1.type == "correct".
# Logistic in L0 with coefficient row beta[4, ] when a0 == 1 and beta[5, ]
# when a0 == 0.
#
# a0     : 0/1 vector (one per row of l0) or scalar.
# l0     : covariate vector (single observation) or matrix (one row each).
# beta   : coefficient matrix with at least 5 rows.
# nb     : unused here; part of the shared p.l1.func signature.
# l1.hat : unused here; part of the shared p.l1.func signature.
# Matrix input is detected with is.matrix(): `class(l0) == "matrix"` has
# length 2 since R 4.0 and errors inside if() in R >= 4.2.
p.l1.func <- function(a0, l0, beta, nb, l1.hat = "holder") {
  if (is.matrix(l0)) {
    phi45 <- expit(cbind(l0 %*% beta[4, ], l0 %*% beta[5, ]))
    # Row i reads column 2 - a0[i]: column 1 (beta[4, ]) for a0 == 1.
    return(phi45[cbind(seq_len(nrow(phi45)), 2 - a0)])
  }
  phi45 <- expit(c(l0 %*% beta[4, ], l0 %*% beta[5, ]))
  return(phi45[2 - a0])
}
# Alternative working models for P(L1 = 1 | a0, L0), selected by the global
# l1.type. Each one replaces the default p.l1.func definition and keeps its
# signature. Matrix input is detected with is.matrix(): `class(l0) == "matrix"`
# has length 2 since R 4.0 and errors inside if() in R >= 4.2.
if (l1.type == "diy") {
  # Fixed model ignoring L0: P(L1 = 1 | a0) = 0.1 + 0.8 * a0.
  p.l1.func <- function(a0, l0, beta, nb, l1.hat = "holder") {
    if (is.matrix(l0)) {
      # Exactly nb values, as the former per-row loop over 1:nb produced.
      return(0.1 + 0.8 * a0[seq_len(nb)])
    }
    return(0.1 + 0.8 * a0)
  }
} else if (l1.type == "incorrect") {
  # Logistic model with coefficient vector l1.hat on the design (l0, a0).
  # Always returns a plain vector (the single-observation case used to
  # return a 1x1 matrix).
  p.l1.func <- function(a0, l0, beta, nb, l1.hat = "holder") {
    design <- if (is.matrix(l0)) cbind(l0, a0) else c(l0, a0)
    return(c(expit(design %*% l1.hat)))
  }
}





# Propensity of the first treatment: expit of l0 times the first two entries
# of the coefficient row gamma["a0", ]. Returns a plain numeric vector.
p.a0.func <- function(l0, gamma) {
  coef.a0 <- gamma["a0", 1:2]
  lin.pred <- cbind(l0) %*% coef.a0
  c(expit(lin.pred))
}

# g1 <- function(l1, a0, l0, p.y.cond, p.a1, theta.k.inv, gamma, nb){
#     p.y.1 <- p.y.func(a0,l1,a1=vec.1, p.y.cond, nb)
#     return(theta.k.inv * p.y.1 * p.a1.func(l1,a0,l0, gamma))
# }

# Building blocks g, g1, g0 of the estimating equation. Their signatures
# depend on the global ee.assume.true.alpha; callers must match the branch.
if (ee.assume.true.alpha) {
  # Known-alpha variant: every term uses the fitted outcome at a1 = 0
  # (vec.0); g1 and g0 split it by P(A1 = 1) and P(A1 = 0).
  g <- function(l1, a0, l0, y.hat) {
    y.hat.func(l0, a0, l1, a1 = vec.0, y.hat)
  }
  g1 <- function(l1, a0, l0, y.hat, gamma) {
    y.a1.zero <- y.hat.func(l0, a0, l1, a1 = vec.0, y.hat)
    y.a1.zero * p.a1.func(l1, a0, l0, gamma)
  }
  g0 <- function(l1, a0, l0, y.hat, gamma) {
    y.a1.zero <- y.hat.func(l0, a0, l1, a1 = vec.0, y.hat)
    y.a1.zero * (1 - p.a1.func(l1, a0, l0, gamma))
  }
} else {
  # Estimated-alpha variant: the A1 = 1 term is reweighted by theta.k.inv.
  g1 <- function(l1, a0, l0, theta.k.inv, gamma, nb, y.hat) {
    y.a1.one <- y.hat.func(l0, a0, l1, a1 = vec.1, y.hat)
    theta.k.inv * y.a1.one * p.a1.func(l1, a0, l0, gamma)
  }
  # theta.k.inv and nb are accepted only to mirror g1's signature.
  g0 <- function(l1, a0, l0, theta.k.inv, gamma, nb, y.hat) {
    y.a1.zero <- y.hat.func(l0, a0, l1, a1 = vec.0, y.hat)
    y.a1.zero * (1 - p.a1.func(l1, a0, l0, gamma))
  }
  # p.y.cond and p.a1 are accepted but unused.
  g <- function(l1, a0, l0, p.y.cond, p.a1, theta.k.inv, gamma, nb, y.hat) {
    part.one <- g1(l1, a0, l0, theta.k.inv, gamma, nb, y.hat)
    part.one + g0(l1, a0, l0, theta.k.inv, gamma, nb, y.hat)
  }
}





# h <- function(a0, l0, theta..inv.a0, 
#               p.y.cond, p.a1, theta.k.inv, gamma, nb,
#               phi45){
#     g.part.1 <- g(l1=vec.1,a0,l0, p.y.cond, p.a1, theta.k.inv, gamma, nb)
#     g.part.0 <- g(l1=vec.0,a0,l0, p.y.cond, p.a1, theta.k.inv, gamma, nb)
#     # g.part.1 <- g(l1=vec.1,a0,l0,y.hat)
#     # g.part.0 <- g(l1=vec.0,a0,l0,y.hat)
#     p.l1 <- p.l1.func(a0, phi45, nb)
#     return( theta..inv.a0 * (p.l1 * g.part.1 + (1-p.l1) * g.part.0))
# }

# h: the L1-marginalized outcome term, mixing the L1 = 1 and L1 = 0
# contributions with weights P(L1 = 1) and P(L1 = 0) from p.l1.func.
# The variant depends on the global ee.assume.true.alpha.
if (ee.assume.true.alpha) {
  # Known-alpha variant: evaluated at a0 = 0 and a1 = 0 (vec.0) regardless of
  # the a0 argument; theta..inv.a0, p.y.cond, p.a1, theta.k.inv and gamma are
  # accepted but unused.
  h <- function(a0, l0, theta..inv.a0,
                p.y.cond, p.a1, theta.k.inv, gamma, nb,
                beta, y.hat, l1.hat = "holder") {
    w.l1 <- p.l1.func(a0 = vec.0, l0, beta, nb, l1.hat)
    w.l1 * y.hat.func(l0, a0 = vec.0, l1 = vec.1, a1 = vec.0, y.hat) +
      (1 - w.l1) * y.hat.func(l0, a0 = vec.0, l1 = vec.0, a1 = vec.0, y.hat)
  }
} else {
  # Estimated-alpha variant: mixes g() at L1 = 1 and L1 = 0, then scales by
  # theta..inv.a0.
  h <- function(a0, l0, theta..inv.a0,
                p.y.cond, p.a1, theta.k.inv, gamma, nb,
                beta, y.hat, l1.hat = "holder") {
    g.at.one <- g(l1 = vec.1, a0, l0, p.y.cond, p.a1, theta.k.inv, gamma, nb, y.hat)
    g.at.zero <- g(l1 = vec.0, a0, l0, p.y.cond, p.a1, theta.k.inv, gamma, nb, y.hat)
    w.l1 <- p.l1.func(a0, l0, beta, nb, l1.hat)
    theta..inv.a0 * (w.l1 * g.at.one + (1 - w.l1) * g.at.zero)
  }
}


# Estimating-function term -l0 * a0 * h(...): h() scaled by the covariates l0
# and switched on only where a0 is nonzero. All arguments are forwarded to h().
f. <- function(a0, l0, theta..inv.a0,
               p.y.cond, p.a1, theta.k.inv, gamma, nb,
               beta, y.hat, l1.hat = "holder") {
  h.value <- h(a0, l0, theta..inv.a0,
               p.y.cond, p.a1, theta.k.inv, gamma, nb,
               beta, y.hat, l1.hat = l1.hat)
  -l0 * a0 * h.value
}

# Per-observation pieces of the k-indexed estimating function.
# Returns an nb x 4 x 2 array; slice [, j, ] corresponds to
#   j = 1: a0 == 1, L1 = 1    j = 2: a0 == 1, L1 = 0
#   j = 3: a0 == 0, L1 = 1    j = 4: a0 == 0, L1 = 0
# Each slice is -l0 * indicator * (theta..inv^a0 for a0 == 1) * P(L1) * g1.
# g1's signature depends on the global ee.assume.true.alpha.
fk <- function(a0, l0, theta..inv,
               p.y.cond, p.a1, theta.k.inv, gamma, nb,
               beta, y.hat, l1.hat = "holder") {
  if (!ee.assume.true.alpha) {
    g1.l1.one <- g1(l1 = vec.1, a0, l0, theta.k.inv, gamma, nb, y.hat)
    g1.l1.zero <- g1(l1 = vec.0, a0, l0, theta.k.inv, gamma, nb, y.hat)
  } else {
    g1.l1.one <- g1(l1 = vec.1, a0, l0, y.hat, gamma)
    g1.l1.zero <- g1(l1 = vec.0, a0, l0, y.hat, gamma)
  }
  prob.l1.one <- p.l1.func(a0, l0, beta, nb, l1.hat)
  prob.l1.zero <- 1 - prob.l1.one # P(L1 = 0)

  w.treated <- -l0 * (a0 == 1) * ((theta..inv)^a0) # dim = (nb, 2)
  w.control <- -l0 * (a0 == 0)

  pieces <- array(dim = c(nb, 4, 2))
  pieces[, 1, ] <- w.treated * prob.l1.one * g1.l1.one
  pieces[, 2, ] <- w.treated * prob.l1.zero * g1.l1.zero
  pieces[, 3, ] <- w.control * prob.l1.one * g1.l1.one
  pieces[, 4, ] <- w.control * prob.l1.zero * g1.l1.zero
  pieces
}

# Counterpart of fk() with the a0 indicators replaced by p.a0 / (1 - p.a0)
# (presumably p.a0 = P(A0 = 1 | L0) -- confirm with caller). P(L1) and g1 are
# evaluated at a0 = 1 (vec.1) and a0 = 0 (vec.0).
# Returns an nb x 4 x 2 array with the same slice layout as fk():
#   j = 1: a0 = 1, L1 = 1    j = 2: a0 = 1, L1 = 0
#   j = 3: a0 = 0, L1 = 1    j = 4: a0 = 0, L1 = 0
# g1 is called with the signature matching the global ee.assume.true.alpha,
# as in fk(). Previously the known-alpha signature was always used, so this
# function failed when ee.assume.true.alpha was FALSE.
E.fk <- function(p.a0, l0, theta..inv,
                 p.y.cond, p.a1, theta.k.inv, gamma, nb,
                 beta, y.hat, l1.hat = "holder") {
  # g1 evaluated at fixed (l1, a0), dispatching on the g1 signature in use.
  g1.at <- function(l1, a0) {
    if (ee.assume.true.alpha) {
      g1(l1 = l1, a0 = a0, l0, y.hat, gamma)
    } else {
      g1(l1 = l1, a0 = a0, l0, theta.k.inv, gamma, nb, y.hat)
    }
  }
  # P(L1 = 1) at each a0 level, computed once each.
  p.l1.a0.one <- p.l1.func(a0 = vec.1, l0, beta, nb, l1.hat)
  p.l1.a0.zero <- p.l1.func(a0 = vec.0, l0, beta, nb, l1.hat)

  w.treated <- -l0 * p.a0 * theta..inv # dim = (nb, 2)
  w.control <- -l0 * (1 - p.a0)

  out <- array(dim = c(nb, 4, 2))
  out[, 1, ] <- w.treated * p.l1.a0.one * g1.at(vec.1, vec.1)
  out[, 2, ] <- w.treated * (1 - p.l1.a0.one) * g1.at(vec.0, vec.1)
  out[, 3, ] <- w.control * p.l1.a0.zero * g1.at(vec.1, vec.0)
  out[, 4, ] <- w.control * (1 - p.l1.a0.zero) * g1.at(vec.0, vec.0)
  return(out)
}

